{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Using TensorFlow backend.\n"
     ]
    }
   ],
   "source": [
    "import numpy as np\n",
    "import json\n",
    "from keras.models import Model\n",
    "from keras.layers import Input\n",
    "from keras.layers.convolutional import Conv2D\n",
    "from keras.layers.pooling import MaxPooling2D, AveragePooling2D\n",
    "from keras.layers.normalization import BatchNormalization\n",
    "from keras import backend as K\n",
    "from collections import OrderedDict"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def format_decimal(arr, places=6):\n",
    "    \"\"\"Round each number in `arr` to `places` decimal places.\n",
    "\n",
    "    Every value is multiplied by 10**places, rounded to the nearest\n",
    "    integer with the built-in round(), and divided by 10**places again.\n",
    "\n",
    "    Args:\n",
    "        arr: iterable of numbers.\n",
    "        places: number of decimal places to keep (default 6).\n",
    "\n",
    "    Returns:\n",
    "        A new list holding the rounded values in the original order.\n",
    "    \"\"\"\n",
    "    rounded = []\n",
    "    for value in arr:\n",
    "        scale = 10**places\n",
    "        rounded.append(round(value * scale) / scale)\n",
    "    return rounded"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Ordered mapping that collects the generated fixture entries, keyed by name.\n",
    "DATA = OrderedDict()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### pipeline 7"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "random_seed = 1007\n",
    "data_in_shape = (8, 8, 2)\n",
    "\n",
    "\n",
    "def seeded_uniform(shape, seed):\n",
    "    \"\"\"Re-seed numpy's global RNG with `seed`, then draw `shape` values in [-1, 1).\"\"\"\n",
    "    np.random.seed(seed)\n",
    "    return 2 * np.random.random(shape) - 1\n",
    "\n",
    "\n",
    "# Pipeline under test: a 3x3 ReLU convolution followed by 2x2 max pooling.\n",
    "layers = [\n",
    "    Conv2D(\n",
    "        4,\n",
    "        (3, 3),\n",
    "        activation='relu',\n",
    "        padding='valid',\n",
    "        strides=(1, 1),\n",
    "        data_format='channels_last',\n",
    "        use_bias=True\n",
    "    ),\n",
    "    MaxPooling2D(\n",
    "        pool_size=(2, 2),\n",
    "        strides=(1, 1),\n",
    "        padding='valid',\n",
    "        data_format='channels_last'\n",
    "    )\n",
    "]\n",
    "\n",
    "# Chain the layers in list order, starting from the input tensor.\n",
    "input_layer = Input(shape=data_in_shape)\n",
    "x = input_layer\n",
    "for layer in layers:\n",
    "    x = layer(x)\n",
    "output_layer = x\n",
    "model = Model(inputs=input_layer, outputs=output_layer)\n",
    "\n",
    "# Seeding happens only after the model has been built, right before each draw.\n",
    "data_in = seeded_uniform(data_in_shape, random_seed)\n",
    "\n",
    "# Replace every weight tensor with seeded random values (seed = random_seed + index).\n",
    "# Index 0 reuses random_seed itself, so the first weight tensor starts with the\n",
    "# same draws as data_in.\n",
    "weights = [\n",
    "    seeded_uniform(w.shape, random_seed + i)\n",
    "    for i, w in enumerate(model.get_weights())\n",
    "]\n",
    "model.set_weights(weights)\n",
    "\n",
    "# Predict on a batch holding the single sample and keep that sample's output.\n",
    "result = model.predict(np.array([data_in]))\n",
    "prediction = result[0]\n",
    "data_out_shape = prediction.shape\n",
    "data_in_formatted = format_decimal(data_in.ravel().tolist())\n",
    "data_out_formatted = format_decimal(prediction.ravel().tolist())\n",
    "\n",
    "# Flatten and round each weight tensor, keeping its shape alongside the values.\n",
    "weights_formatted = [\n",
    "    {'data': format_decimal(w.ravel().tolist()), 'shape': w.shape}\n",
    "    for w in weights\n",
    "]\n",
    "\n",
    "DATA['pipeline_07'] = {\n",
    "    'input': {'data': data_in_formatted, 'shape': data_in_shape},\n",
    "    'weights': weights_formatted,\n",
    "    'expected': {'data': data_out_formatted, 'shape': data_out_shape}\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### export for Keras.js tests"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "# Export the collected fixture data as JSON for the Keras.js tests.\n",
    "filename = '../../test/data/pipeline/07.json'\n",
    "\n",
    "# exist_ok=True creates the directory when it is missing and does nothing when\n",
    "# it already exists, which removes the separate exists() check and the\n",
    "# check-then-create race that came with it.\n",
    "os.makedirs(os.path.dirname(filename), exist_ok=True)\n",
    "\n",
    "# Tuples stored in DATA (the shapes) are serialized as JSON arrays.\n",
    "with open(filename, 'w') as f:\n",
    "    json.dump(DATA, f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{\"pipeline_07\": {\"input\": {\"data\": [0.0965, 0.594443, -0.987782, 0.431322, 0.067427, -0.470888, 0.41181, 0.052982, 0.484153, 0.136201, 0.553358, -0.655578, -0.424989, 0.031041, -0.023031, 0.913774, 0.568027, -0.978941, -0.353237, -0.297017, 0.639236, 0.391994, 0.591057, -0.833079, 0.945585, 0.970605, 0.433942, -0.655561, -0.087023, -0.250158, 0.925941, -0.64485, 0.621258, -0.009644, -0.629281, -0.373595, 0.485602, -0.815495, 0.064045, -0.79823, 0.686509, 0.961811, -0.093605, -0.512734, -0.570547, 0.682538, 0.433208, 0.69226, 0.070745, -0.89601, -0.657133, -0.529907, -0.147531, 0.674429, 0.455914, -0.517776, -0.556739, 0.119107, 0.656276, -0.02947, 0.396705, 0.691676, 0.196748, 0.166465, -0.211257, -0.678947, 0.21344, -0.690575, -0.091028, 0.60432, 0.438176, 0.066163, -0.762816, -0.750232, 0.079407, 0.002266, -0.138425, 0.131461, 0.422326, 0.740444, -0.550612, 0.162234, 0.835085, 0.032045, 0.059742, 0.821414, 0.540964, -0.365625, 0.940139, -0.016265, -0.00874, -0.67179, 0.839636, 0.204161, 0.055069, -0.80054, -0.321087, -0.543786, 0.737843, 0.592256, -0.642479, 0.168902, 0.356467, -0.306618, -0.137864, -0.815771, 0.700587, 0.4349, 0.889517, -0.443741, -0.644906, -0.847551, -0.875106, -0.050461, -0.744143, 0.796467, -0.304847, 0.635756, -0.855731, 0.981174, 0.866329, -0.555073, 0.070089, -0.42501, -0.63011, -0.021348, 0.444343, -0.394568], \"shape\": [8, 8, 2]}, \"weights\": [{\"data\": [0.0965, 0.594443, -0.987782, 0.431322, 0.067427, -0.470888, 0.41181, 0.052982, 0.484153, 0.136201, 0.553358, -0.655578, -0.424989, 0.031041, -0.023031, 0.913774, 0.568027, -0.978941, -0.353237, -0.297017, 0.639236, 0.391994, 0.591057, -0.833079, 0.945585, 0.970605, 0.433942, -0.655561, -0.087023, -0.250158, 0.925941, -0.64485, 0.621258, -0.009644, -0.629281, -0.373595, 0.485602, -0.815495, 0.064045, -0.79823, 0.686509, 0.961811, -0.093605, -0.512734, -0.570547, 0.682538, 0.433208, 0.69226, 0.070745, -0.89601, -0.657133, -0.529907, -0.147531, 0.674429, 0.455914, 
-0.517776, -0.556739, 0.119107, 0.656276, -0.02947, 0.396705, 0.691676, 0.196748, 0.166465, -0.211257, -0.678947, 0.21344, -0.690575, -0.091028, 0.60432, 0.438176, 0.066163], \"shape\": [3, 3, 2, 4]}, {\"data\": [0.228005, 0.859479, -0.49018, 0.232871], \"shape\": [4]}], \"expected\": {\"data\": [1.86288, 1.350518, 0.0, 1.534338, 1.756351, 3.939944, 0.0, 1.450033, 2.893287, 3.939944, 0.0, 0.0, 2.893287, 2.984937, 1.910038, 0.162801, 2.387606, 3.327839, 1.910038, 0.439481, 1.86288, 1.350518, 0.0, 3.294138, 1.756351, 3.939944, 0.0, 1.450033, 1.756351, 3.939944, 0.0, 0.46946, 1.683364, 2.537879, 1.910038, 1.010561, 1.530127, 2.537879, 1.910038, 1.010561, 0.827691, 3.657351, 0.0, 3.294138, 0.8945, 1.52444, 1.271375, 0.0, 0.8945, 1.52444, 1.271375, 1.054137, 0.509853, 2.275232, 0.0, 1.054137, 1.530127, 2.275232, 0.0, 1.149238, 2.337285, 3.657351, 0.201048, 1.358881, 2.337285, 2.194357, 1.271375, 0.110681, 1.403623, 2.194357, 1.271375, 1.054137, 1.403623, 2.454154, 0.749519, 1.054137, 1.519193, 2.454154, 0.749519, 1.149238, 2.337285, 4.052855, 0.544261, 1.047269, 2.337285, 4.052855, 0.52913, 1.047269, 1.669021, 4.129228, 0.52913, 0.737423, 2.727855, 4.129228, 0.749519, 0.737423, 2.727855, 2.454154, 0.749519, 0.0], \"shape\": [5, 5, 4]}}}\n"
     ]
    }
   ],
   "source": [
    "# Print the JSON serialization of DATA as this cell's output.\n",
    "print(json.dumps(DATA))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "anaconda-cloud": {},
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
